{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train MobileNet classifier using Weight Imprinting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the dataset: every sub-directory is treated as one category and the\n",
    "# files inside it as that category's images.\n",
    "data_folder = '/home/pi/dataset/ttt-boxes-smalldb/'\n",
    "# Fraction of each category's images held out for evaluation; at least one\n",
    "# image per category is always held out, whatever the ratio.\n",
    "test_ratio = 0.1\n",
    "\n",
    "# Base name shared by the two files this notebook writes.\n",
    "output_basename = 'ttt-boxes'\n",
    "# Destination of the retrained model.\n",
    "output_model = f'{output_basename}.tflite'\n",
    "# Destination of the label map: one '<class id> <category>' line per class.\n",
    "output_labelmap = f'{output_basename}.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model file the ImprintingEngine is created from.\n",
    "pretrained_model_path = '/home/pi/models/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite'\n",
    "# Forwarded as ImprintingEngine(keep_classes=...).\n",
    "# NOTE(review): presumably False discards the classes of the pre-trained model --\n",
    "# confirm against the Edge TPU API documentation.\n",
    "keep_classes = False\n",
    "# Size every training image is resized to before it is flattened.\n",
    "input_shape = (224, 224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edgetpu.learn.imprinting.engine import ImprintingEngine\n",
    "\n",
    "# Engine used further down to train on the dataset and to save the resulting model.\n",
    "train_engine = ImprintingEngine(pretrained_model_path, keep_classes=keep_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load train/test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Image paths per category, keyed by category (sub-directory) name.\n",
    "train_set, test_set = {}, {}\n",
    "# Integer class id -> category name.\n",
    "labels_map = {}\n",
    "\n",
    "# os.listdir() returns entries in arbitrary order. Sort both the categories and\n",
    "# the images so that class ids and the train/test split are the same on every run.\n",
    "categories = sorted(\n",
    "    entry for entry in os.listdir(data_folder)\n",
    "    if os.path.isdir(os.path.join(data_folder, entry))\n",
    ")\n",
    "\n",
    "for ci, category in enumerate(categories):\n",
    "    category_dir = os.path.join(data_folder, category)\n",
    "\n",
    "    images = sorted(\n",
    "        os.path.join(category_dir, f)\n",
    "        for f in os.listdir(category_dir)\n",
    "        if os.path.isfile(os.path.join(category_dir, f))\n",
    "    )\n",
    "\n",
    "    # Always hold out at least one image per category for evaluation.\n",
    "    # NOTE(review): a category with a single image therefore ends up with an\n",
    "    # empty training list -- make sure every category has at least two images.\n",
    "    k = max(int(test_ratio * len(images)), 1)\n",
    "\n",
    "    test_set[category] = images[:k]\n",
    "    train_set[category] = images[k:]\n",
    "\n",
    "    labels_map[ci] = category\n",
    "\n",
    "for c in train_set.keys():\n",
    "    print(f'Label {c}: train imgs {len(train_set[c])} - test imgs {len(test_set[c])}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def prepare_image(image_list, input_shape):\n",
    "    \"\"\"Load image files and return them as one array of flattened pixels.\n",
    "\n",
    "    Every file in ``image_list`` is opened with PIL, converted to RGB, resized\n",
    "    to ``input_shape`` with nearest-neighbour resampling and flattened.\n",
    "\n",
    "    Args:\n",
    "        image_list: iterable of image file paths.\n",
    "        input_shape: size passed to ``Image.resize``.\n",
    "\n",
    "    Returns:\n",
    "        ``np.ndarray`` built from the list of flattened images, one per file.\n",
    "    \"\"\"\n",
    "    flattened = []\n",
    "\n",
    "    for path in image_list:\n",
    "        with Image.open(path) as source:\n",
    "            resized = source.convert('RGB').resize(input_shape, Image.NEAREST)\n",
    "            flattened.append(np.asarray(resized).flatten())\n",
    "    return np.array(flattened)\n",
    "\n",
    "print('Processing train images...')\n",
    "# One array per category, in the iteration order of train_set.\n",
    "train_data = []\n",
    "for category_images in train_set.values():\n",
    "    train_data.append(prepare_image(category_images, input_shape))\n",
    "print('Done!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Start training...')\n",
    "# train_data holds one array per category, in the order the categories were\n",
    "# added to train_set.\n",
    "# NOTE(review): class ids are assumed to follow that list order, which is what\n",
    "# labels_map relies on -- confirm against the ImprintingEngine.train_all docs.\n",
    "train_engine.train_all(train_data)\n",
    "print('Done!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_engine.save_model(output_model)\n",
    "\n",
    "# Label map: one '<class id> <category>' line per class, in labels_map order.\n",
    "label_lines = [\n",
    "    f'{label_id} {label}\\n'\n",
    "    for label_id, label in labels_map.items()\n",
    "]\n",
    "\n",
    "with open(output_labelmap, 'w') as labels_file:\n",
    "    labels_file.writelines(label_lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edgetpu.classification.engine import ClassificationEngine\n",
    "\n",
    "# Evaluate the model that was just saved, on the held-out images.\n",
    "test_engine = ClassificationEngine(output_model)\n",
    "\n",
    "total = 0\n",
    "nb_images = 0\n",
    "\n",
    "\n",
    "for category, image_list in test_set.items():\n",
    "    correct = 0\n",
    "\n",
    "    for img_path in image_list:\n",
    "        # test_set already stores complete paths. Joining them with data_folder\n",
    "        # and the category again only works when data_folder is absolute (the\n",
    "        # earlier components are then discarded) and breaks for a relative one.\n",
    "        # The context manager makes sure the file handle is released.\n",
    "        with Image.open(img_path) as img:\n",
    "            results = test_engine.classify_with_image(img, threshold=0.1, top_k=1)\n",
    "\n",
    "        # The result list is empty when no class reaches the threshold;\n",
    "        # count that as a miss instead of failing with an IndexError.\n",
    "        if results and labels_map[results[0][0]] == category:\n",
    "            correct += 1\n",
    "\n",
    "    print(f'Evaluating category \"{category}\": {correct}/{len(image_list)}')\n",
    "\n",
    "    total += correct\n",
    "    nb_images += len(image_list)\n",
    "\n",
    "print(f'Total {total}/{nb_images}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
